# -*- coding: utf-8 -*-
# type: ignore
from copy import deepcopy
from datetime import datetime
from typing import List, Dict, Optional, Any, Literal, TypeAlias, Annotated
from typing import Union

try:
    from typing import Self
except ImportError:
    from typing_extensions import Self
from uuid import uuid4

from pydantic import BaseModel, Field, field_validator
from openai.types.chat import ChatCompletionChunk


class MessageType:
    """String constants naming every supported protocol message type."""

    MESSAGE = "message"
    FUNCTION_CALL = "function_call"
    FUNCTION_CALL_OUTPUT = "function_call_output"
    PLUGIN_CALL = "plugin_call"
    PLUGIN_CALL_OUTPUT = "plugin_call_output"
    COMPONENT_CALL = "component_call"
    COMPONENT_CALL_OUTPUT = "component_call_output"
    MCP_LIST_TOOLS = "mcp_list_tools"
    MCP_APPROVAL_REQUEST = "mcp_approval_request"
    MCP_TOOL_CALL = "mcp_call"
    MCP_APPROVAL_RESPONSE = "mcp_approval_response"
    HEARTBEAT = "heartbeat"
    ERROR = "error"

    @classmethod
    def all_values(cls):
        """Return every public string constant declared on this class."""
        return [
            member
            for key, member in vars(cls).items()
            if isinstance(member, str) and not key.startswith("_")
        ]


class ContentType:
    """Kinds of content parts a message may carry."""

    TEXT = "text"  # plain text
    DATA = "data"  # structured payloads (e.g. tool-call data)
    IMAGE = "image"  # image URL / base64 data URL
    AUDIO = "audio"  # audio URL / base64 data URL


class Role:
    """Author roles a message may be attributed to."""

    ASSISTANT = "assistant"
    USER = "user"
    SYSTEM = "system"
    TOOL = "tool"


class RunStatus:
    """Lifecycle states for agent event messages and responses."""

    Created = "created"
    InProgress = "in_progress"
    Completed = "completed"
    Canceled = "canceled"
    Failed = "failed"
    Rejected = "rejected"
    Unknown = "unknown"


class FunctionParameters(BaseModel):
    """JSON Schema object describing a function's parameters."""

    type: str
    """The type of the parameters object. Must be `object`."""

    properties: Dict[str, Any]
    """The properties of the parameters object."""

    # Default to None so the field is genuinely optional: JSON Schema
    # allows `required` to be omitted, and in pydantic v2 an
    # Optional[...] annotation WITHOUT a default is still a mandatory
    # field — the previous declaration forced callers to pass it.
    required: Optional[List[str]] = None
    """The names of the required properties."""


class FunctionTool(BaseModel):
    """Declaration of a callable function exposed to the model."""

    name: str
    """Name of the function to be called."""

    description: str
    """What the function does; the model uses this to decide when and
    how to invoke it."""

    parameters: Union[FunctionParameters, Dict[str, Any]]
    """The accepted parameters, expressed as a JSON Schema object."""


class Tool(BaseModel):
    """A tool entry attached to a request or tool-call message."""

    type: Optional[str] = "function"
    """Tool kind; currently only `function` is supported."""

    function: Optional[FunctionTool] = None
    """The function declaration for `function`-type tools."""


class FunctionCall(BaseModel):
    """A function invocation emitted by the model."""

    call_id: Optional[str] = None
    """Identifier of this tool call."""

    name: Optional[str] = None
    """Name of the function being invoked."""

    arguments: Optional[str] = None
    """JSON-encoded argument string produced by the model.

    The model does not always generate valid JSON and may hallucinate
    parameters outside the declared schema — validate these arguments
    before executing the function.
    """


class FunctionCallOutput(BaseModel):
    """The result produced by executing a function/tool call."""

    call_id: str
    """Identifier of the tool call this output answers."""

    output: str
    """Serialized result returned by the function."""


class Error(BaseModel):
    """Structured error payload attached to failed events."""

    code: str
    """Machine-readable error code."""

    message: str
    """Human-readable error description."""


class Event(BaseModel):
    """Base class for streamed agent events, with fluent status setters."""

    sequence_number: Optional[int] = None
    """Ordinal position of this event in the stream."""

    object: str
    """Identity tag describing what kind of object this event is."""

    status: Optional[str] = None
    """Current run status; one of the `RunStatus` values."""

    error: Optional[Error] = None
    """Error details recorded when the run fails."""

    def _set_status(self, status: str) -> Self:
        """Assign *status* and return self for fluent chaining."""
        self.status = status
        return self

    def created(self) -> Self:
        """Mark this event as 'created'."""
        return self._set_status(RunStatus.Created)

    def in_progress(self) -> Self:
        """Mark this event as 'in_progress'."""
        return self._set_status(RunStatus.InProgress)

    def completed(self) -> Self:
        """Mark this event as 'completed'."""
        return self._set_status(RunStatus.Completed)

    def failed(self, error: Error) -> Self:
        """Mark this event as 'failed' and attach the error."""
        self.error = error
        return self._set_status(RunStatus.Failed)

    def rejected(self) -> Self:
        """Mark this event as 'rejected'."""
        return self._set_status(RunStatus.Rejected)

    def canceled(self) -> Self:
        """Mark this event as 'canceled'."""
        return self._set_status(RunStatus.Canceled)


class Content(Event):
    """Base class for one part of a message's content list."""

    type: str
    """Discriminator naming the concrete content kind."""

    object: str = "content"
    """Object identity tag for this part."""

    index: Optional[int] = None
    """Position of this part within the parent message's content list."""

    delta: Optional[bool] = False
    """True when this part carries an incremental (streaming) fragment."""

    msg_id: Optional[str] = None
    """Unique id of the message this part belongs to."""

    @staticmethod
    def from_chat_completion_chunk(
        chunk: ChatCompletionChunk,
        index: Optional[int] = None,
    ) -> Optional[Union["TextContent", "DataContent", "ImageContent"]]:
        """Build a delta content part from an OpenAI streaming chunk.

        Returns a TextContent for token deltas, a DataContent for the
        first tool-call delta, or None when the chunk carries neither.
        """
        if not chunk.choices:
            return None

        delta = chunk.choices[0].delta
        if delta.content is not None:
            return TextContent(delta=True, text=delta.content, index=index)

        # TODO: support multiple tool calls output
        if delta.tool_calls:
            call = delta.tool_calls[0]
            if call.function is None:
                return None
            return DataContent(
                delta=True,
                data={
                    "call_id": call.id,
                    "name": call.function.name,
                    "arguments": call.function.arguments,
                },
                index=index,
            )
        return None


class ImageContent(Content):
    """Content part holding an image."""

    type: Literal[ContentType.IMAGE] = ContentType.IMAGE
    """Discriminator: always `image`."""

    image_url: Optional[str] = None
    """URL (or data URL) of the image."""


class TextContent(Content):
    """Content part holding plain text."""

    type: Literal[ContentType.TEXT] = ContentType.TEXT
    """Discriminator: always `text`."""

    text: Optional[str] = None
    """The text payload."""


class DataContent(Content):
    """Content part holding structured data (e.g. tool-call payloads)."""

    type: Literal[ContentType.DATA] = ContentType.DATA
    """Discriminator: always `data`."""

    data: Optional[Dict] = None
    """The structured data payload."""


# Roles a message author may take, constrained to the four string
# constants declared on `Role`.
AgentRole: TypeAlias = Literal[
    Role.ASSISTANT,
    Role.SYSTEM,
    Role.USER,
    Role.TOOL,
]


# Discriminated union of the concrete content-part models; pydantic
# dispatches on the `type` field when validating `Message.content`.
AgentContent = Annotated[
    Union[TextContent, ImageContent, DataContent],
    Field(discriminator="type"),
]


class Message(Event):
    """A single protocol message exchanged during an agent run.

    Carries a role, a message `type` (one of `MessageType`), and a list
    of content parts that may be populated incrementally while streaming.
    """

    id: str = Field(default_factory=lambda: "msg_" + str(uuid4()))
    """message unique id"""

    object: str = "message"
    """message identity"""

    type: str = "message"
    """The type of the message."""

    status: str = RunStatus.Created
    """The status of the message. in_progress, completed, or incomplete"""

    role: Optional[AgentRole] = None
    """The role of the messages author, should be in `user`,`system`,
    'assistant'."""

    content: Optional[List[AgentContent]] = None
    """The contents of the message."""

    code: Optional[str] = None
    """The error code of the message."""

    message: Optional[str] = None
    """The error message of the message."""

    usage: Optional[Dict] = None
    """response usage for output"""

    @staticmethod
    def from_openai_message(message: Union[BaseModel, dict]) -> "Message":
        """Create a message object from an openai message.

        Accepts a `Message` (returned unchanged), any pydantic model
        (dumped to a dict first), a dict already in Message format, or
        a dict in OpenAI chat-completion format.
        """

        # in case message is a Message object
        if isinstance(message, Message):
            return message

        # make sure operation on dict object
        if isinstance(message, BaseModel):
            message = message.model_dump()

        # in case message is a Message format dict
        if "type" in message and message["type"] in MessageType.all_values():
            return Message(**message)

        # handle message in openai message format.
        # Use .get() truthiness here: model_dump() of an assistant
        # message without tool calls produces `"tool_calls": None`, and
        # the previous `"tool_calls" in message` check then crashed by
        # iterating over None.
        if message["role"] == Role.ASSISTANT and message.get("tool_calls"):
            _content_list = []
            for tool_call in message["tool_calls"]:
                _content = DataContent(
                    data=FunctionCall(
                        call_id=tool_call["id"],
                        name=tool_call["function"]["name"],
                        arguments=tool_call["function"]["arguments"],
                    ).model_dump(),
                )
                _content_list.append(_content)
            _message = Message(
                type=MessageType.FUNCTION_CALL,
                content=_content_list,
            )
        elif message["role"] == Role.TOOL:
            _content = DataContent(
                data=FunctionCallOutput(
                    call_id=message["tool_call_id"],
                    output=message["content"],
                ).model_dump(),
            )
            _message = Message(
                type=MessageType.FUNCTION_CALL_OUTPUT,
                content=[_content],
            )
        # mainly focus on matching content
        elif isinstance(message["content"], str):
            _content = TextContent(text=message["content"])
            _message = Message(
                type=MessageType.MESSAGE,
                role=message["role"],
                content=[_content],
            )
        else:
            _content_list = []
            for content in message["content"]:
                if content["type"] == "image_url":
                    _content = ImageContent(
                        image_url=content["image_url"]["url"],
                    )
                elif content["type"] == "text":
                    _content = TextContent(text=content["text"])
                else:
                    # NOTE(review): unknown part types fall back to
                    # content["text"] — confirm this is intended rather
                    # than wrapping the whole part dict.
                    _content = DataContent(data=content["text"])
                _content_list.append(_content)
            _message = Message(
                type=MessageType.MESSAGE,
                role=message["role"],
                content=_content_list,
            )
        return _message

    def get_text_content(self) -> Optional[str]:
        """
        Extract the first text content from the message.

        :return:
            First text string found in the content, or None if no text content
        """
        if self.content is None:
            return None

        for item in self.content:
            if isinstance(item, TextContent):
                return item.text
        return None

    def get_image_content(self) -> List[str]:
        """
        Extract all image content (URLs or base64 data) from the message.

        :return:
            List of image URLs or base64 encoded strings found in the content
        """
        images = []

        if self.content is None:
            return images

        for item in self.content:
            if isinstance(item, ImageContent):
                images.append(item.image_url)
        return images

    def get_audio_content(self) -> List[str]:
        """
        Extract all audio content (URLs or base64 data) from the message.

        NOTE(review): this duck-types OpenAI-style `input_audio` parts
        rather than a ContentType.AUDIO part — confirm which content
        shape callers actually produce.

        :return:
            List of audio URLs or base64 encoded strings found in the content
        """
        audios = []

        if self.content is None:
            return audios

        for item in self.content:
            if hasattr(item, "type"):
                if item.type == "input_audio" and hasattr(
                    item,
                    "input_audio",
                ):
                    if hasattr(item.input_audio, "data"):
                        audios.append(item.input_audio.data)
                    elif hasattr(item.input_audio, "base64_data"):
                        # Construct data URL for audio
                        format_type = getattr(
                            item.input_audio,
                            "format",
                            "mp3",
                        )
                        audios.append(
                            f"data:{format_type};base64,"
                            f"{item.input_audio.base64_data}",
                        )

        return audios

    def add_delta_content(
        self,
        new_content: Union[TextContent, ImageContent, DataContent],
    ):
        """Merge a streamed content fragment into this message.

        A fragment without an index starts a new content slot; a delta
        fragment with an index is appended onto the existing slot.
        Returns the (annotated) fragment, or None if nothing was merged.
        """
        self.content = self.content or []

        # new content: register a cleaned copy and hand back the
        # original annotated with its slot index and owning message id
        if new_content.index is None:
            copy = deepcopy(new_content)
            copy.delta = None
            copy.index = None
            copy.msg_id = None
            self.content.append(copy)

            new_content.index = len(self.content) - 1
            new_content.msg_id = self.id
            new_content.in_progress()
            return new_content

        # delta content: concatenate onto the slot at new_content.index
        if new_content.delta is True:
            pre_content = self.content[new_content.index]
            _type = pre_content.type

            # append text
            if _type == ContentType.TEXT:
                pre_content.text += new_content.text

            # append image_url
            if _type == ContentType.IMAGE:
                pre_content.image_url += new_content.image_url

            # append data: only merge keys whose existing and incoming
            # values are the same list/str type, extending in place
            if _type == ContentType.DATA:
                for key in new_content.data:
                    if (
                        key in pre_content.data
                        and isinstance(pre_content.data[key], (list, str))
                        and isinstance(
                            new_content.data[key],
                            type(pre_content.data[key]),
                        )
                    ):
                        if isinstance(pre_content.data[key], list):
                            pre_content.data[key].extend(new_content.data[key])
                        elif isinstance(pre_content.data[key], str):
                            pre_content.data[key] += new_content.data[key]
            new_content.msg_id = self.id
            new_content.in_progress()
            return new_content

        return None

    def content_completed(self, content_index: int):
        """Return a completed snapshot of the content at *content_index*,
        or None when the index is absent/out of range."""
        if self.content is None:
            return None
        if content_index >= len(self.content):
            return None
        else:
            content = self.content[content_index]
            new_content = deepcopy(content)
            new_content.delta = False
            new_content.index = content_index
            new_content.msg_id = self.id
            new_content.completed()
            return new_content

    def add_content(
        self,
        new_content: Union[TextContent, ImageContent, DataContent],
    ):
        """Append a complete (non-delta) content part and return it
        annotated with its index; None if it already carries an index."""
        self.content = self.content or []

        # new content
        if new_content.index is None:
            copy = deepcopy(new_content)
            self.content.append(copy)

            new_content.index = len(self.content) - 1
            new_content.msg_id = self.id
            new_content.completed()
            return new_content

        return None


class BaseRequest(BaseModel):
    """Common request payload shared by all agent requests."""

    input: List[Message]
    """Input messages for the run."""

    stream: bool = True
    """If set, partial message deltas will be sent, like in ChatGPT."""
    """If set, partial message deltas will be sent, like in ChatGPT. """


class AgentRequest(BaseRequest):
    """Agent request carrying model, sampling, and tool options."""

    model: Optional[str] = None
    """
    model id
    """

    top_p: Optional[float] = None
    """Nucleus sampling, between (0, 1.0],  where the model considers the
    results of the tokens with top_p probability  mass.

    So 0.1 means only the tokens comprising the top 10% probability mass are
    considered.

    We generally recommend altering this or `temperature` but not both.
    """

    temperature: Optional[float] = None
    """What sampling temperature to use, between 0 and 2.

    Higher values like 0.8 will make the output more random, while lower values
    like 0.2 will make it more focused and deterministic.

    We generally recommend altering this or `top_p` but not both.
    """

    frequency_penalty: Optional[float] = None
    """Positive values penalize new tokens based on their existing frequency in
    the text so far, decreasing the model's likelihood to repeat the same line
    verbatim.

    """

    presence_penalty: Optional[float] = None
    """Number between -2.0 and 2.0.

    Positive values penalize new tokens based on whether they appear in the
    text so far, increasing the model's likelihood to talk about new topics.

    """

    max_tokens: Optional[int] = None
    """The maximum number of [tokens](/tokenizer) that can be generated in the
    chat completion.

    The total length of input tokens and generated tokens is limited by the
    model's context length.
    """

    # Simplified from Optional[Union[Optional[str], List[str]]] — the
    # nested Optional was redundant (Union flattens the inner None).
    stop: Optional[Union[str, List[str]]] = None
    """Up to 4 sequences where the API will stop generating further tokens."""

    n: Optional[int] = Field(default=1, ge=1, le=5)
    """How many chat completion choices to generate for each input message.

    Note that you will be charged based on the number of generated tokens
    across all of the choices. Keep `n` as `1` to minimize costs.
    """

    seed: Optional[int] = None
    """If specified, system will make a best effort to sample
    deterministically, such that repeated requests with the same `seed` and
    parameters should return the same result.
    """

    tools: Optional[List[Union[Tool, Dict]]] = None
    """
    tool call list
    """

    session_id: Optional[str] = None
    """conversation id for dialog"""

    response_id: Optional[str] = None
    """response unique id"""


class BaseResponse(Event):
    """Top-level event describing a whole agent run/response."""

    id: Optional[str] = Field(
        default_factory=lambda: "response_"
        + str(
            uuid4(),
        ),
    )
    """response unique id"""

    @field_validator("id", mode="before")
    @classmethod
    def validate_id(cls, v):
        """Regenerate an id when None is explicitly supplied."""
        if v is None:
            return "response_" + str(uuid4())
        return v

    object: str = "response"
    """response identity"""

    status: str = RunStatus.Created
    """response run status"""

    # Use default_factory so the timestamp is captured per instance.
    # The previous `= int(datetime.now().timestamp())` default was
    # evaluated once at import time, stamping every response with the
    # module-load time instead of its own start time.
    created_at: int = Field(
        default_factory=lambda: int(datetime.now().timestamp()),
    )
    """request start time (unix seconds)"""

    completed_at: Optional[int] = None
    """request completed time"""

    output: Optional[List[Message]] = None
    """response data for output"""

    usage: Optional[Dict] = None
    """response usage for output"""

    def add_new_message(self, message: Message):
        """Append *message* to the output list, creating it if needed."""
        self.output = self.output or []
        self.output.append(message)


class AgentResponse(BaseResponse):
    """Agent response, additionally bound to a conversation session."""

    session_id: Optional[str] = None
    """Identifier of the dialog session this response belongs to."""


def convert_to_openai_tool_call(
    function: "Union[FunctionCall, Dict[str, Any]]",
) -> Dict[str, Any]:
    """Convert a function-call payload to the OpenAI tool_call shape.

    The original annotation claimed `FunctionCall`, but the body called
    `.get()`, which pydantic models do not have — callers actually pass
    the `DataContent.data` dict. Accept both: a model is dumped to a
    dict first, then read with `.get()` as before.
    """
    if not isinstance(function, dict):
        function = function.model_dump()
    return {
        "id": function.get("call_id", None),
        "type": "function",
        "function": {
            "name": function.get("name", None),
            "arguments": function.get("arguments", None),
        },
    }


def convert_to_openai_messages(messages: List[Message]) -> List[Dict]:
    """
    Convert a generic message protocol to a model-specific protocol.

    Args:
        messages: Original list of messages
    Returns:
        list: Message format required by the model
    """
    converted = []
    for msg in messages:
        # `content` is Optional on Message; guard against None so a
        # content-less message no longer raises TypeError mid-loop.
        content = msg.content or []

        if MessageType.MESSAGE == msg.type:
            converted.append(
                {
                    "role": msg.role,
                    "content": [c.model_dump() for c in content],
                },
            )
        elif MessageType.FUNCTION_CALL == msg.type:
            converted.append(
                {
                    "role": Role.ASSISTANT,
                    "tool_calls": [
                        convert_to_openai_tool_call(c.data) for c in content
                    ],
                },
            )
        elif MessageType.FUNCTION_CALL_OUTPUT == msg.type:
            # one tool-result entry per content part
            for function_call_output in content:
                converted.append(
                    {
                        "role": "tool",
                        "tool_call_id": function_call_output.data.get(
                            "call_id",
                        ),
                        "content": function_call_output.data.get("output"),
                    },
                )
    return converted


def convert_to_openai_tools(tools: List[Union[Tool, Dict]]) -> Optional[list]:
    """Normalize a mixed list of Tool models and raw dicts.

    Returns None for an empty/missing list; otherwise each Tool model is
    dumped to a plain dict and existing dicts pass through untouched.
    """
    if not tools:
        return None

    def _as_dict(entry):
        # Tool instances are serialized; plain dicts are kept as-is.
        return entry.model_dump() if isinstance(entry, Tool) else entry

    return [_as_dict(entry) for entry in tools]
